# https://docs.openvino.ai/2024/notebooks/vision-monodepth-with-output.html
# https://colab.research.google.com/github/openvinotoolkit/openvino_notebooks/blob/latest/notebooks/vision-monodepth/vision-monodepth.ipynb#scrollTo=1c9f693b
import time

import openvino as ov
import cv2
import numpy
import torch
import torch.nn as nn
from torchvision import models
from torch.nn import functional as F
from torchvision.transforms import transforms


def get_kernel(kernel_len=16, nsig=10):
    """Build a normalized 2-D Gaussian kernel of shape (kernel_len, kernel_len).

    Computed as the outer product of a 1-D Gaussian with itself, using the
    same formula OpenCV documents for ``cv2.getGaussianKernel`` (coefficients
    ``exp(-(i - (k-1)/2)^2 / (2*sigma^2))`` normalized to sum to 1), so this
    pure-numpy version matches the previous cv2-based implementation while
    dropping the cv2 dependency from this helper.

    :param kernel_len: kernel size (number of taps per axis)
    :param nsig: Gaussian standard deviation (must be > 0)
    :return: float64 ndarray of shape (kernel_len, kernel_len); the full
        2-D kernel sums to 1 (product of two unit-sum 1-D kernels)
    """
    offsets = numpy.arange(kernel_len, dtype=numpy.float64) - (kernel_len - 1) / 2.0
    gauss_1d = numpy.exp(-(offsets ** 2) / (2.0 * float(nsig) ** 2))
    # Column vector, like cv2.getGaussianKernel, so the outer product below
    # yields the (k, k) separable 2-D kernel.
    gauss_1d = (gauss_1d / gauss_1d.sum()).reshape(-1, 1)
    return gauss_1d * gauss_1d.T


class Gaussian_kernel(nn.Module):
    """Fixed (non-trainable) Gaussian blur applied to each channel independently.

    The kernel weights come from :func:`get_kernel` and are stored as a frozen
    ``nn.Parameter`` of shape (1, 1, k, k). Replication padding of ``k // 2``
    is applied first, so for odd kernel sizes the spatial size is preserved.
    """

    def __init__(self, kernel_len, nsig=20):
        """
        :param kernel_len: Gaussian kernel size
        :param nsig: Gaussian standard deviation
        """
        super(Gaussian_kernel, self).__init__()
        self.kernel_len = kernel_len
        gauss = torch.FloatTensor(get_kernel(kernel_len=kernel_len, nsig=nsig))
        # (k, k) -> (1, 1, k, k): out_channels=1, in_channels=1 for conv2d.
        self.weight = nn.Parameter(data=gauss[None, None, :, :], requires_grad=False)
        self.padding = torch.nn.ReplicationPad2d(self.kernel_len // 2)

    def forward(self, x):
        """Blur every channel of ``x`` with the stored Gaussian kernel.

        Each channel slice ``x[:, c, :, :]`` has shape (N, H, W) and is fed to
        ``F.conv2d`` as an unbatched (C=N, H, W) input; since the weight has
        in_channels=1 this effectively assumes batch size 1. The per-channel
        results are stacked along dim 0, giving a (C, H', W') output.
        """
        padded = self.padding(x)
        blurred = [F.conv2d(padded[:, c, :, :], self.weight)
                   for c in range(padded.shape[1])]
        return torch.cat(blurred, dim=0)


class DensenetGrade:
    """Fundus-image grading with a DenseNet-121 classifier exported to OpenVINO IR.

    Pipeline: circular retina crop -> Gaussian-blur-based contrast enhancement
    (Ben Graham style ``addWeighted`` sharpening) -> CHW float tensor ->
    OpenVINO CPU inference -> softmax -> argmax grade.
    """

    def __init__(self, pth_path: str):
        """Load and compile the OpenVINO IR model for CPU inference.

        :param pth_path: path to the exported OpenVINO IR ``.xml`` file
        """
        core = ov.Core()
        cpu_model = core.read_model(pth_path)
        self.model = core.compile_model(model=cpu_model, device_name="CPU")
        self.output_key = self.model.output(0)

        self.result_name = "Densenet121Grade"

        # Fixed Gaussian blur used to estimate the low-frequency background
        # that crop_image_with_gaussian subtracts out.
        self.gaussian_kernel = Gaussian_kernel(11, 30)
        self.transform_kernel = transforms.Compose([
            transforms.ToTensor()
        ])

    def softmax(self, x):
        """Numerically stable softmax over ``x``.

        Subtracting the max before exponentiating prevents ``exp`` overflow
        (inf/nan) for large logits; the result is mathematically identical
        because the shift cancels in the ratio.
        """
        shifted = numpy.exp(x - numpy.max(x))
        return shifted / numpy.sum(shifted)

    @torch.no_grad()
    def infer(self, image_path: str):
        """Run the full pipeline on one image file and return the grade dict.

        :param image_path: path to the input image on disk
        :return: ``{'index': argmax class, 'pie': per-class probabilities}``
        """
        trans_dim = self._preprocess(image_path)
        result = self.model([trans_dim.numpy()])[self.output_key]
        y = self.softmax(result[0]).tolist()
        return self._postprocess(y)

    def _postprocess(self, pred_arr):
        """Wrap the probability list with its argmax index."""
        return {
            'index': pred_arr.index(max(pred_arr)),
            'pie': pred_arr
        }

    def _preprocess(self, data_in: str):
        """Crop + enhance the image and convert to a (1, 3, H, W) float tensor.

        ``[:, :, ::-1]`` flips the channel order; note crop_image_with_gaussian
        already converted BGR->RGB, so this flips back to BGR — kept as-is
        since it presumably matches the channel order the model was trained
        with (TODO confirm against the training pipeline).
        """
        img_arr = self.circle_crop(data_in) / 255.0
        x = torch.from_numpy(img_arr[:, :, ::-1].astype(numpy.float32).transpose((2, 0, 1))).unsqueeze(0)
        return x

    def circle_crop(self, image_src: str):
        """Circular-crop the retina region, then apply Gaussian enhancement."""
        crop_mask = self.crop_image_from_mask(image_src)
        return self.crop_image_with_gaussian(crop_mask)

    def crop_image_from_mask(self, image_src: str):
        """Crop the image to the largest inscribed circle of its bright region.

        Thresholds near-black background away, bounds the remaining content,
        and zeroes out everything outside the inscribed circle.

        :param image_src: path to the image file
        :return: BGR ndarray of the square circular crop
        """
        # load
        image = cv2.imread(image_src)

        # binary: anything brighter than 7 counts as retina content
        gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        _, binary_image = cv2.threshold(gray_image, 7, 255, cv2.THRESH_BINARY)

        # cal roi
        x, y, w, h = cv2.boundingRect(binary_image)
        center = (w // 2), (h // 2)
        radius = min(center)
        # Clamp to image bounds: for an off-center/non-square ROI the offset
        # could go negative, and a negative slice start silently wraps around
        # and crops the wrong region.
        y = max(0, y + center[1] - radius)
        x = max(0, x + center[0] - radius)
        copy_image = image[y: y + 2 * radius, x: x + 2 * radius]

        # gen mask: 1 inside the inscribed circle, 0 outside
        mask = numpy.zeros_like(copy_image)
        cv2.circle(mask, (radius, radius), radius, (1, 1, 1), -1)

        # exposure
        return copy_image * mask

    def crop_image_with_gaussian(self, data_in: numpy.ndarray):
        """Resize to 224x224 and apply blur-subtract contrast enhancement.

        ``addWeighted(img, 4, blur, -4, 128)`` boosts local contrast by
        subtracting the Gaussian-blurred background (Ben Graham preprocessing).

        :param data_in: BGR uint8/float image
        :return: enhanced RGB uint8 ndarray of shape (224, 224, 3)
        """
        ori_image = cv2.resize(data_in, (224, 224)).astype(numpy.float32)
        with torch.no_grad():
            image_cuda = self.transform_kernel(ori_image).unsqueeze(0)
            out = numpy.transpose(self.gaussian_kernel(image_cuda).cpu().numpy(), (1, 2, 0))

        # The replication-pad/conv combination can leave one extra row/column
        # for even kernel sizes; trim back to exactly 224x224.
        if out.shape != (224, 224, 3):
            out = out[0: 224, 0: 224]
        exposure = cv2.addWeighted(ori_image, 4, out, -4, 128)
        exposure = numpy.clip(exposure, 0, 255).astype(numpy.uint8)
        exposure = cv2.cvtColor(exposure, cv2.COLOR_BGR2RGB)
        return exposure


if __name__ == '__main__':
    # Reference fp32 output for regression checking:
    # {'index': 2, 'pie': [0.1433304101228714, 0.05587254837155342, 0.7058674693107605, 0.08943081647157669, 0.0054987091571092606]}
    grade = DensenetGrade("export_dense121_cpu.xml")
    start = time.perf_counter()
    print(grade.infer("1.jpg"))
    # Elapsed wall-clock time for a single end-to-end inference.
    print(time.perf_counter() - start)
    # Debug helpers for inspecting the intermediate crops:
    # cv2.imwrite("2-1.jpg", grade.crop_image_from_mask("1-1.jpg"))
    # cv2.imwrite("2-2.jpg",  grade.circle_crop("1-1.jpg"))